{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "getting_started_jax.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BYzpwF2FLeCH"
      },
      "source": [
        "# Approximate Inference in Bayesian Deep Learning: Getting Started in JAX\n",
        "\n",
        "In this colab we walk you through downloading the data, running your method, and generating a submission for our NeurIPS 2021 competition. We use the JAX framework; for PyTorch see ...\n",
        "\n",
        "Useful references:\n",
        "- [Competition website](https://izmailovpavel.github.io/neurips_bdl_competition/)\n",
        "- [Efficient implementation of several baselines in JAX](https://github.com/google-research/google-research/tree/master/bnn_hmc)\n",
        "- [Submission platform](https://competitions.codalab.org/competitions/33647)\n",
        "\n",
        "\n",
        "## Setting up colab\n",
        "\n",
        "Colab provides an easy-to-use environment for working on the competition with access to free computational resources. However, you should also be able to run this notebook locally after installing the required dependencies. If you use colab, please select a `GPU` runtime type.\n",
        "\n",
        "## Preparing the data\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EaSX9Gq9jujD",
        "outputId": "f016ce21-6c54-4487-8d2a-9ae6b8dd87a8"
      },
      "source": [
        "# uncomment to re-install the starter kit\n",
        "# !rm -rf neurips_bdl_starter_kit\n",
        "\n",
        "!git clone https://github.com/izmailovpavel/neurips_bdl_starter_kit\n",
        "!pip install git+https://github.com/deepmind/dm-haiku\n",
        "!pip install optax\n",
        "\n",
        "import sys\n",
        "import math\n",
        "import jax\n",
        "import tensorflow as tf\n",
        "import optax\n",
        "import matplotlib\n",
        "import numpy as onp\n",
        "from jax import numpy as jnp\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "sys.path.append(\"neurips_bdl_starter_kit\")\n",
        "import jax_models as models"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'neurips_bdl_starter_kit'...\n",
            "remote: Enumerating objects: 92, done.\u001b[K\n",
            "remote: Counting objects: 100% (92/92), done.\u001b[K\n",
            "remote: Compressing objects: 100% (64/64), done.\u001b[K\n",
            "remote: Total 92 (delta 42), reused 70 (delta 25), pack-reused 0\u001b[K\n",
            "Unpacking objects: 100% (92/92), done.\n",
            "Collecting git+https://github.com/deepmind/dm-haiku\n",
            "  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-yquwnma_\n",
            "  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-yquwnma_\n",
            "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0) (0.12.0)\n",
            "Collecting jmp>=0.0.2\n",
            "  Downloading https://files.pythonhosted.org/packages/ff/5c/1482f4a4a502e080af2ca54d7f80a60b5d4735f464c151666d583b78c226/jmp-0.0.2-py3-none-any.whl\n",
            "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0) (1.19.5)\n",
            "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0) (0.8.9)\n",
            "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.7/dist-packages (from dm-haiku==0.0.5.dev0) (3.7.4.3)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->dm-haiku==0.0.5.dev0) (1.15.0)\n",
            "Building wheels for collected packages: dm-haiku\n",
            "  Building wheel for dm-haiku (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for dm-haiku: filename=dm_haiku-0.0.5.dev0-cp37-none-any.whl size=531840 sha256=9caf0adb4913aafe2a1bc9b71cc6ba10001a493914f96f7b586f71793cc845e1\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-3err8ly5/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499\n",
            "Successfully built dm-haiku\n",
            "Installing collected packages: jmp, dm-haiku\n",
            "Successfully installed dm-haiku-0.0.5.dev0 jmp-0.0.2\n",
            "Collecting optax\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/07/48/4f65dbb5ec096917ec039ba2c7eccf97ee05a4157e0e965a45ed3b7a13f9/optax-0.0.9-py3-none-any.whl (118kB)\n",
            "\u001b[K     |████████████████████████████████| 122kB 7.6MB/s \n",
            "\u001b[?25hRequirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)\n",
            "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.17)\n",
            "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.69+cuda110)\n",
            "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)\n",
            "Collecting chex>=0.0.4\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/0f/95/ccd2da57155c019efb3a60e3e5ecb9da431e19ebb16cce1e6981d615d75e/chex-0.0.8-py3-none-any.whl (57kB)\n",
            "\u001b[K     |████████████████████████████████| 61kB 7.0MB/s \n",
            "\u001b[?25hRequirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n",
            "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.4.1)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n",
            "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)\n",
            "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n",
            "Installing collected packages: chex, optax\n",
            "Successfully installed chex-0.0.8 optax-0.0.9\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mrcvRuuJRGhZ"
      },
      "source": [
        "We provide the datasets used in this competition in a public Google Cloud Storage bucket in the `.csv` format. Below, we download the data using `gsutil`.\n",
        "\n",
        "You can also download the data to your computer by clicking these links:\n",
        "- [CIFAR-10 train features](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_train_x.csv)\n",
        "- [CIFAR-10 train labels](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_train_y.csv)\n",
        "- [CIFAR-10 test features](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_test_x.csv)\n",
        "- [CIFAR-10 test labels](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_test_y.csv)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "da-9ecSDnKO9",
        "outputId": "54d5cdd7-74b9-4b05-c7a5-88a7135e045b"
      },
      "source": [
        "!gsutil -m cp -r gs://neurips2021_bdl_competition/cifar10_*.csv .\n",
        "!gsutil -m cp -r gs://neurips2021_bdl_competition/imdb_*.csv ."
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Copying gs://neurips2021_bdl_competition/cifar10_test_x.csv...\n",
            "/ [0/4 files][    0.0 B/  4.4 GiB]   0% Done                                    \rCopying gs://neurips2021_bdl_competition/cifar10_test_y.csv...\n",
            "/ [0/4 files][    0.0 B/  4.4 GiB]   0% Done                                    \rCopying gs://neurips2021_bdl_competition/cifar10_train_x.csv...\n",
            "Copying gs://neurips2021_bdl_competition/cifar10_train_y.csv...\n",
            "/ [4/4 files][  4.4 GiB/  4.4 GiB] 100% Done  64.5 MiB/s ETA 00:00:00           \n",
            "Operation completed over 4 objects/4.4 GiB.                                      \n",
            "Copying gs://neurips2021_bdl_competition/imdb_test_x.csv...\n",
            "Copying gs://neurips2021_bdl_competition/imdb_test_y.csv...\n",
            "Copying gs://neurips2021_bdl_competition/imdb_train_x.csv...\n",
            "Copying gs://neurips2021_bdl_competition/imdb_train_y.csv...\n",
            "|\n",
            "Operation completed over 4 objects/120.4 MiB.                                    \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDsZHMQzSGpZ"
      },
      "source": [
        "We can now read the data and convert it into numpy arrays. This cell may take several minutes to run."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 303
        },
        "id": "4o8NLOGgnVSV",
        "outputId": "664d281f-5ca5-45f8-e9dd-61892980b143"
      },
      "source": [
        "x_train = onp.loadtxt(\"cifar10_train_x.csv\")\n",
        "y_train = onp.loadtxt(\"cifar10_train_y.csv\")\n",
        "x_test = onp.loadtxt(\"cifar10_test_x.csv\")\n",
        "y_test = onp.loadtxt(\"cifar10_test_y.csv\")\n",
        "\n",
        "x_train = x_train.reshape((len(x_train), 32, 32, 3))\n",
        "x_test = x_test.reshape((len(x_test), 32, 32, 3))\n",
        "\n",
        "plt.imshow(x_train[0])\n",
        "plt.colorbar()"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f61d35047d0>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAD8CAYAAADJwUnTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfAUlEQVR4nO3df5yWdZ3v8ddHGnVRw2hYIpR0kY04UWiTuEkWmkTiKu2Gqbsd2lRalGM/tHM4dkSzHkfNzKPnGC0qK/YQDfpBnKDEFPWwuxFIBIKooIawo0AqWbPmMPM5f1zX1D0z1/d7XzNzz33f1/B+Ph7Xg5nv5/rx5Z7hw3Vd31/m7oiIFMkhta6AiEhPKXGJSOEocYlI4ShxiUjhKHGJSOEocYlI4ShxiUi/MbOFZrbHzJ4IxM3MbjOz7Wa2ycxOynNeJS4R6U93A1Mj8Y8BY9JtFjA/z0mVuESk37j7Y8DLkV3OBe7xxM+Bo81sRLnzvqlSFczjcDM/IhBrjxz3RqDcIse0RmIHIrHYBxKqh0g9cffYP42yPjr5CP/Ny2259n180x+2AK+XFC1w9wU9uNxI4IWS73elZc2xg/qUuMxsKnArMAi4091viO1/BOF7xpbIcTsD5Q2RY2J/632RWGMv6iEykPzm5TZ+8cCoXPsOGvHM6+7e1M9V6qbXicvMBgG3A2eSZMl1Zrbc3bdWqnIiUn0OtEefgSpqN3BsyffHpGVRfXnHdTKw3d2fdfc3gPtJnldFpMAcp9Xbcm0VsBz4z2nr4inAfnePPiZC3x4Vs55NJ3bdycxmkbQWMLgPFxOR6qnUHZeZ3Qd8GGg0s13ANaRvedz928BK4CxgO8kbo3/Ic95+fzmfvqhbAPBWM82hI1LnHKetQtNdufsFZeIOXNbT8/YlcfXq2VRE6l879X2P0ZfEtQ4YY2bHkySs84ELYwfsB1YEYrEDb/3CVzLLdzSHH4W3Lft2+ISvh0OxlsoNgfJY14vY43GslXJZJCbSnxxoG6iJy90PmNkc4AGS7hAL3X1LxWomIjUzkO+4cPeVJC/XRGSAcKC1zqd0r2rPeRGpf44P3EdFERmgHNrqO28pcYlIZ0nP+fqmxCUiXRht0SkMas+qua6i9bIDqv/sD9mBSF8D+8BhwVisy8Pvf/K7YCzU/eJdn3lf5Izh4ePfmnBdMHbpxquCsbGRq108PPv/ohlDwse84+nYfBlhU3hzMLaB32aWxwa4H6xCP5r9vTxfX2eHePd7DvUlK4bl2vc/jfr3xws1yFpEBqakH1d933EpcYlIN+19u2nrd0pcItKJ7rhEpHAco63OZ3VX4hKRbvSoWAEf/Mh5meUT3z+54tdq2Rlu91q7IjT0ObsFrZyG/aFh2/GWw/GcHD5u9tWZ5aPGTwoe44sWB2O3rbgzGLv4G9cHYw2jRmcHWiNtupH5uz995eeCscUv/yh8YJ0YH4n9v6HZ5TsjS0yEBuh/Pm+FIhzjDR9UgTP1n0IkLhGpnqQDqh4VRaRg9HJeRArF3Whz3XGJSMG0645LRIokeTlf36mhvmsnIlVXhJfzhRhkLX3ng24KBy+cEo5NHxGO7Q93HVm9M7vBfvPmHcFjThof7jQwf354kfTFL9XHJLyxLizTOq0r09k3tvw4OzA4sq764LdnFjdNaWL9xvV9es47Yfxg//qyd+ba929P2KhB1iJSe+o5LyKF1K5WRREpkmSQtRKXiBSIY7RqyI+IFIk76oAqIkVj6oAqdaLtS+FYY6Q7xN9E5h7/t3Bs8dLsmTTu/MlXg8fMXTs9GGt+aXW4HnViWyQ2JDbb/qrsWMvo8EIBDaOyy701UomcnAF+x2VmzwOvAW3AgVr05xCRyjsYXs5Pdnct3iIyQDimiQRFpFgcaK3zsYp9vR90YJWZPW5ms7J2MLNZZrbezNb38VoiUhXJgrB5tlrpa1qd5O67ze
zPgQfNbJu7P1a6g7svABaAxiqKFIFT/z3n+1Q7d9+d/rkH+CFEJkMXkcIYsHdcZnYEcIi7v5Z+PQUIrykPHMpRvI1TMmM7iTV3926JeMlpW3jRDnhPMHLnDfcGYzsDi46MJtxev3bXfcFYBVr5+93gXsb278teQGRbQ3P4oJbsY1pe7/sn5W4VveMys6nArcAg4E53v6FLfBSwCDg63Weuu0en/OjLo+Jw4Idm1nGexe7+0z6cT0TqQPJyvjJDfsxsEHA7cCawC1hnZsvdfWvJbv8DWOLu881sHLASOC523l4nLnd/Fnhvb48XkXpV0TnnTwa2p/kCM7sfOBcoTVwOvDn9egjw7+VOWt9tniJSdcnL+dzvrxq79BhYkDbIdRgJvFDy/S5gYpdzXEvSO+G/AEcAHyl3USUuEemmBz3n91VgxMwFwN3ufrOZ/RXwHTN7t7u3hw5Q4hKRTircc343dJq3+pi0rNRFwFQAd/83MzscaAT2hE5a3501RKQm2jkk15bDOmCMmR1vZocC5wPLu+yzEzgDwMzeBRwO7I2dtKp3XO/E+XGgYbuFk4LHjSYwFD7SyLyCe4Kx7HkLOuoRFlo2Ym3kmFgsJtZ8Hqvjw725WGNkUYaIHcsXB2MTj8r+tFZFuraEl9GoH6MjsezlQRLj+Y9gbPUN8zLLNw8J/6THTsme0eONV6L/3nNxh9b2ytzTuPsBM5sDPEDS1WGhu28xs+uA9e6+HLgCuMPMvkDyiu3TXmYVHz0qikgnyaNi5R7G0j5ZK7uUzSv5eitwak/OqcQlIt3Usld8HkpcItJJD7tD1IQSl4h0UdlHxf6gxCUi3WjO+RKHvu+djFrfi7nDHwqMANgRboeaviK8IPr05UsjF9sfiWUPbI21J22OtCbF5iiPxa4e+t/CwSWBpeo/EmqZBe45Oxz7ypPBULgdGBpei32OPTc5Eov9RoXaS+f28nzhWeC7dwcvFWtxvK3tkczywS+Hj1lz/y8yyysxFXHSqqjlyUSkQDR1s4gUkh4VRaRQ1KooIoWkVkURKRR344ASl4gUjR4VK+GMt/esHGDWByMn/O99qk5PjO9l7O0W/sWZ9/KNwdi3LslusJ/9XLhB/qOWPWAXYBUPBmN+yzPBGK3Z86X72knhY9ZG5r7ftTkc45fBSKhTRgNvCx4zgxeDsUinkj54c6A81vkiexB7E0/0uTZ6xyUihaTEJSKFon5cIlJI6sclIoXiDgcqNJFgf1HiEpFu9KgoIoWid1x1qvVX4dj+yKQGjadVuCJd1zopEVl8PerS567KLj8+u7xPJp0QjjWFYrFuKpUX61AQ0j9dHqqkqa8rhSW8zhNX2QdZM1toZnvM7ImSsqFm9qCZPZP++Zb+raaIVFM7lmurlTxv4O4mXfOsxFzgIXcfAzxEfHojESkQ9+QdV56tVsomLnd/DOg6pdm5wKL060XA9ArXS0RqxmhrPyTXViu9fcc13N07XsO8CAwP7Whms4BZAKNGFfrtgchBo/DvuMpJF24MLt7o7gvcvcndm4YNG9bXy4lIP+sYq1joR8WAl8xsBED6557KVUlEasqT91x5tlrp7aPicmAmcEP6548qVqNKWR/+VC85+7xgbOdL4Y4IjUdlL70wedq04DEXXj4zGPvoByLdCQpgxftDi4fAtDnXZQf+d/Vm5pDeK/yQHzO7D/gw0Ghmu4BrSBLWEjO7CPg1EM4EIlIonr6cr2dlE5e7XxAInVHhuohInajlY2AeB2XPeRGJq/dWRSUuEekkefGuxCUiBaNB1iJSOHrH1Y/mf/FbwdiMSBeFiZPCy1Qs+v73whd8LVC+JtwtYPZ9lwRjMz80ORhb++g9wdj1wUg1lwGBaR+bHYztn784s3zx/wnPUjF7Y+Rfy3tzV6sTs9DPOryohNf7v9p+5hjtRW9VFJGDT72n7vpOqyJSfenL+TxbHmY21cyeMrPtZpY5k4yZnWdmW81si5ll366X0B2XiHRXoV
suMxsE3A6cCewC1pnZcnffWrLPGJK3HKe6+ytm9uflzqs7LhHppoJ3XCcD2939WXd/A7ifZFqsUpcAt7v7K8m1vezYZyUuEenEgfZ2y7WRDAVcX7LN6nK6kcALJd/vSstK/SXwl2b2L2b2czPrOnFpN3pUFJHOHMjfj2ufu/d1ovs3AWNIxkQfAzxmZuPd/dXYAXVvzT8/mVl+6S2XBY+ZfcWlkdgVwdil378mf8VSS3dFulBEzJh+Ybgeke4QG3p1td65PBacODYYGjL76szyeeeEX18smhD+xzLq+E8EY0ufWxaMwYFILJtZuB4HS1eJCv41dwPHlnx/DN2XidkFrHX3VuA5M3uaJJGtC51Uj4oi0p3n3MpbB4wxs+PN7FDgfJJpsUotI7nbwswaSR4dn42dtBB3XCJSTfm7OpTj7gfMbA7wADAIWOjuW8zsOmC9uy9PY1PMbCvQBnzJ3X8TO68Sl4h0V8EnYndfCazsUjav5GsHvphuuShxiUhnDt6uQdYiUjhKXH328c98vOcHDY7FjgiGJnJqMLaWf8ksb4h8jDt/kN0iCrCzOTy/fczSXh3VO7EWzM3Xfi4YG//HZTc72xc5361Dwy2Hy57bHIxNIbzs3ar4O14JqfPG00IkLhGpMiUuESmUnnVArQklLhHppt772SpxiUh3alUUkaIx3XGJSKHkH85TM4VIXOOPyZ43fPWup4LHHDF0SDA2e0J4cPOooY3B2NqXs8tbIwN5V61aE4zNnDYlGBvy9WCI/eFQxY2OxFZH/t4tQ7PLfz76H4PHTLx3fvhil8wMhhY/WnbCzIpZ893fB2OTPhnuZlMsVvcv58sOsjazhWa2x8yeKCm71sx2m9nGdDurf6spIlVVuUHW/SLP7BB3A1kTe93i7hPSbWVGXESKqj3nViNlHxXd/TEzO67/qyIidaEA/bj6Mh/XHDPblD5KviW0k5nN6pjWde/evX24nIhUi3m+rVZ6m7jmk7y3nQA0AzeHdnT3Be7e5O5Nw4YN6+XlRKSqBsA7rm7c/SV3b3P3duAOkpU8RESqolfdIcxshLt3TG3wcWLrmVfAwy9kz4dw+tumB49paAmfb/So0LLsMGPGjGCs9YbsjgjLXnskeMy2DTuCsYaZ4VkNXv2nrcHY6Z8dF4ytDkZ654r3/89gbFtz+GotMydllk/+2rzM8nIaCHdTWdGLeeV7a+3qtcHYpE+eXrV69LfCd0A1s/tI5oNuNLNdwDXAh81sAsnN4vPAZ/uxjiJSTU7xh/y4+wUZxXf1Q11EpF4U/Y5LRA4+hX9UFJGDkBKXiBSOEpeIFEmtO5fmUejE9fCLsaXXw7566jeCseZ9O4Oxu+/9Vmb50eeEuyfs2xGey2HbsnB3grGzw03rkyM/ttiMDb2xoiVc//GjwnNH7AwsVtLyyqvBY1qWhruOnPfoN4OxmPGdVn//k8280Kvz7d8Z/v0YUIreqigiBx/dcYlI8ShxiUih6B2XiBSSEpeIFI3VcJLAPPoyH5eISE0clHdc8/71S8HYvWP/VzA2ZMq7MstvPSc8g8KIIeGZKHZs2ByMtVy5IXxcFWdDaG1tCMZGDR4RjN15W/YCFhtWhP9et/3rj/JXLKdNnt194YQ3TQ4es6PtkWBs/Ojw33lA0aOiiBSKXs6LSCEpcYlI4ShxiUiRGGpVFJGiybnCT973YGY21cyeMrPtZjY3st/fmpmbWVO5cx6Ud1zhdjKYNDnc2hQ68OLZFwcPmX/zbcHYtjVrgrHm15uDsRXBSOXNe/prwVgr/xiMXX159oD0DWvDrYrQu1bFJeff1ONjbv3G7GDs7C88EozNmDGxx9cqpAo9KprZIOB24ExgF7DOzJa7+9Yu+x0FfA4IT+pfQndcItJd5ZYnOxnY7u7PuvsbwP3AuRn7fRW4EXg9z0mVuESkmx48KjZ2LPicbrO6nGokdJpDaFda9qdrmZ0EHOvuuR8mDspHRREpI/+j4j53L/tOKs
TMDgG+CXy6J8cpcYlIZ17RVsXd0Gk2x2PSsg5HAe8GHjEzgLcBy83sHHdfHzqpEpeIdFe5flzrgDFmdjxJwjofuPCPl3HfD39a7dfMHgGujCUt0DsuEclQqe4Q7n4AmAM8ADwJLHH3LWZ2nZmd09v65VnJ+ljgHmA4SR5e4O63mtlQ4LvAcSSrWZ/n7q/0tiLV9IZXtlvw4LHDgrEN28LN/82vh+cvbyY813vMhZyZWb6KB4PH7OvVleDi668Pxkb9zdGZ5WN/cFLwmEuWhwe/x0yaNr3Hx0z7/HnB2JQvfDJysey/14BTwX8i7r4SWNmlbF5g3w/nOWeeO64DwBXuPg44BbjMzMYBc4GH3H0M8FD6vYgUXd6uEDUcFlQ2cbl7s7tvSL9+jeR2byRJX4xF6W6LgJ7/tycidceobM/5/tCjl/NmdhxwIknv1uHu3tG9+0WSR0kRGQDqfVqb3C/nzexI4PvA5939t6Uxdw/eOJrZrI7OaXv37u1TZUWkSor+qAhgZg0kSeted/9BWvySmY1I4yOAPVnHuvsCd29y96Zhw8IvsUWkjhQ9cVnSK+wu4El3L11OeDkwM/16Jr0dISsi9aXCs0P0hzzvuE4FPgVsNrONadlVwA3AEjO7CPg1EG5fHuBWL3s4GJs2I9xmMXnU2GDs9C+EP84G3hyM3btxVXZgdPAQmje8GowNHhxu/h/Si4Eeg6dVvjvBvBvuDMbu+Psbeny+B372f8PBg6XnY52/4yqbuNx9DUlDQ5YzKlsdEakH9T6RoIb8iEg39d6qqMQlIp3V+MV7HkpcItKdEpeIFElHz/l6psQlIt1Ye31nLiWuHlhx048zyxctzl5uHmDJL8OxmMuXXRGMXfpoeBaFfYEFPRqPDF9rxGlVnPHgsMqf8s4tNwZjd9Dz7hCccXYfajMA6B2XiBSRHhVFpHiUuESkaHTHJSLFo8QlIoVS2VV++oUSl4h0on5cA8zZ//WvM8s3/eSZil9r9nVXBmOXfijcHWLzjuyZHiaPO0gWeehi6YLHM8tnzHpflWtSMBVeUKbSlLhEpBvdcYlIsagDqogUkV7Oi0jhKHGJSLE4ejlfNBt++mSPjxk/9YTKV+S03h22ZumKzPLJf/13fahM7V3+me8GY7ctDP/dVqzK/jzUqhinl/MiUjxKXCJSJOqAKiLF466JBEWkgOo7bylxiUh3elQUkWJxoOiPimZ2LHAPMJzkr7TA3W81s2uBS4C96a5XufvK/qpotexrCUzaDjy8vPKDqXvn2GBkxZq1meVXUx/dIbY99PtgbPzws4KxW+86LxibPmNaMLZqRXZ3CCmjvvMWh+TY5wBwhbuPA04BLjOzcWnsFnefkG6FT1oikjDPt+U6l9lUM3vKzLab2dyM+BfNbKuZbTKzh8zsHeXOWTZxuXuzu29Iv34NeBIYma/KIlJE1u65trLnMRsE3A58DBgHXFBy49Phl0CTu78H+B7w9XLnzXPHVVqJ44ATgY7nkTlpllxoZm/pyblEpE55D7byTga2u/uz7v4GcD9wbqfLua9295b0258Dx5Q7ae7EZWZHAt8HPu/uvwXmA6OBCUAzcHPguFlmtt7M1u/duzdrFxGpI0kHVM+1AY0d/77TbVaX040EXij5fhfxJ7aLgJ+Uq2OuVkUzayBJWve6+w8A3P2lkvgdQOZqqe6+AFgA0NTUVOev/EQEgPyzQ+xz96ZKXNLM/h5oAj5Ubt88rYoG3AU86e7fLCkf4e7N6bcfB57oXXVFpN5Y5WaH2E3nZvBj0rLO1zP7CPBl4EPu/odyJ81zx3Uq8Clgs5ltTMuuInnJNoHkSfd54LM5zlX35t92ZzD2w0d6sZx7P1jyk/AMFn93Xn10ewj54EdPD8b2HsjuylHO5KlHBGNrVzf26pwHtcrOgLoOGGNmx5MkrPOBC0t3MLMTgX8Cprr7njwnLZu43H0NyWNvV+r+IDIgVW6sorsfMLM5wAPAIGChu28xs+
uA9e6+HLgJOBJYmjzgsdPdz4mdVz3nRaS7Ck4kmPbxXNmlbF7J1x/p6TmVuESkMy0IKyKFpKmbRaRw6jtvKXGJSHfWXt/PikpcXWx+9LZItD66Q8yINP+f99qPqliTsNU3ZfZHprFtc+Uv9rtwaOnXL8ksn3vjjsrXY6BwetIBtSaUuESkE8Mr2QG1XyhxiUh3SlwiUjhKXCJSKHrHJSJFpFZFESkY16NiZ23Aq4HYvshxg7OLX4kcs21nMDTv7PDCC438R/icP707u7wlfC0aA3UHGBJemIOW8HFLF/V8AYj5J44KxmbPvTgYa90X/oxb9g8Jxi758tcyy69+f3ihDx66PhjavyHcjWL+/GXB2LbQz/OfvxWux8zJ4dghkZ8ZsZkojo7EeqMfE4ujxCUiBVTfT4pKXCLSnfpxiUjxKHGJSKG4Q1t9PysqcYlId7rjEpHCUeIq8eLTcNNHs2Mt2cUA7MwO7tvRnFkOsPjR3wRj4cbzeKeM8z72D5nl4U4BycKTwdjwcKwl0up+267ISQOaN74QjC2+5JpgbEikhX/Hc5FYoHzZunA9Rlx6VTC2P/L7MTjyedw9NBDYsSp80M1Lw7EhkYo0Rn4TRoW7o4SvFTnfmLGBQOw3OCcHKjTnfH/RHZeIdOHgesclIkXi6OW8iBSQ3nGJSOEocYlIsQyAQdZmdjjwGHBYuv/33P2adEnt+4G3Ao8Dn3L3N+JnewMIDEhuDrfWtAZaFZc9eiB4TKwN5+EPhGM7I+Old4ZarwaFjxkxIhxriIy/Hhw57upIi2NzoKH1wsi44VjLIUPeGo61hg+cvuap7EMilxodaigrZ8rbIsFAHYdEahJrJh4b+cEQG4AdaY0M/SJEfj/4XaD1sD38byI3B+p8WptDcuzzB+B0d38vMAGYamanADcCt7j7CcArwEX9V00RqSr3fFuNlE1cnuhYR6Uh3Rw4HfheWr4ImN4vNRSRKkuH/OTZaiTPHRdmNsjMNgJ7gAdJ+he+6u4d96W7gJH9U0URqSoH9/ZcW63kejnv7m3ABDM7GvghkPtthJnNAmYBjDo6V54UkVqr857zPcok7v4qsBr4K+BoM+tIfMcAuwPHLHD3JndvGnaEEpdIIRT9HZeZDUvvtDCzPwPOBJ4kSWCfSHebCdTHEsoi0jfuSatinq1G8jwqjgAWmdkgkkS3xN1/bGZbgfvN7GvAL4G7yp7pyD+DieOzY0P2Bw9rCDRBXzwisoz6+MB1ACaGn3Qbd24Lxk5qCfQ1aI2NAA6Hok3kg8MdOsY2Rp7UdwbqGJvfPtb8H23iD3cNGDU58PNsCP+cGRyJtUS6L4yfEo4NCdRxf2QwcqweI2MdbWKdPWK/CKHPOPZGJnC+Q+6LHNMDRe/H5e6bgBMzyp8FTu6PSolILTne1lbrSkSp57yIdKZpbUSkkOp8Whs184lIJw54u+fa8jCzqWb2lJltN7O5GfHDzOy7aXytmR1X7pxKXCLSmacTCebZykgb9W4HPgaMAy4ws3FddrsIeCUdPngLyXDCKCUuEenG29pybTmcDGx392fTSRjuB87tss+5JMMGIRlGeIaZWeyk5lVs9jSzvcCv028bqcgE2X2menSmenRWtHq8w92H9eVCZvZTgtNqdHM48HrJ9wvcfUHJuT4BTHX3i9PvPwVMdPc5Jfs8ke6zK/1+R7pP8O9b1ZfzpR+oma1396ZqXj+L6qF6qB6dufvUalynL/SoKCL9aTdwbMn3WcMD/7hPOoxwCBBepgslLhHpX+uAMWZ2vJkdCpwPLO+yz3KSYYOQDCN82Mu8w6plP64F5XepCtWjM9WjM9WjD9z9gJnNAR4gmSt4obtvMbPrgPXuvpxkuOB3zGw78DJJcouq6st5EZFK0KOiiBSOEpeIFE5NEle5IQBVrMfzZrbZzDaa2foqXnehme1J+690lA01swfN7Jn0z7fUqB7Xmtnu9DPZaGZnVaEex5rZajPbamZbzOxzaXlVP5
NIPar6mZjZ4Wb2CzP7VVqPr6Tlx6dDYranQ2QO7c961DV3r+pG8oJuB/AXwKHAr4Bx1a5HWpfngcYaXPc04CTgiZKyrwNz06/nAjfWqB7XAldW+fMYAZyUfn0U8DTJ8JCqfiaRelT1MwEMODL9ugFYC5wCLAHOT8u/Dcyu5s+pnrZa3HHlGQIwoLn7YyStJ6VKhz1UZdWkQD2qzt2b3X1D+vVrJDPsjqTKn0mkHlXlCa2sFVGLxDUSeKHk+1quEOTAKjN7PF3Uo5aGu3vH9KUvAsNrWJc5ZrYpfZTs90fWUunMACeS3GXU7DPpUg+o8meilbXiDvaX85Pc/SSSkeuXmdlpta4QJP/jkiTVWpgPjCZZ/LcZuLlaFzazI4HvA59399+Wxqr5mWTUo+qfibu3ufsEkp7mJ9ODlbUOBrVIXHmGAFSFu+9O/9xDsuxaLaeifsnMRgCkf+6pRSXc/aX0H007cAdV+kzMrIEkWdzr7j9Ii6v+mWTVo1afSXrtHq+sdTCoReLKMwSg35nZEWZ2VMfXwBTgifhR/ap02EPNVk3qSBSpj1OFzySdwuQu4El3/2ZJqKqfSage1f5MTCtrlVeLFgHgLJIWmx3Al2tUh78gadH8FbClmvUA7iN55GgleVdxEfBW4CHgGeBnwNAa1eM7wGZgE0niGFGFekwieQzcBGxMt7Oq/ZlE6lHVzwR4D8nKWZtIkuS8kt/ZXwDbgaXAYdX6na23TUN+RKRwDvaX8yJSQEpcIlI4SlwiUjhKXCJSOEpcIlI4SlwiUjhKXCJSOP8fV4asJwmR+wEAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z8JIisIsSeOO"
      },
      "source": [
        "Finally, we define a `tf.data.Dataset` for the train and test datasets."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Iev-FreJoGBt"
      },
      "source": [
        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))"
      ],
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ui5dKavnSoQ9"
      },
      "source": [
        "## Model and losses\n",
        "\n",
        "We provide the code for all the models used in the competition in the `neurips_bdl_starter_kit/jax_models.py` module. Here, we will load a ResNet-20 model with filter response normalization (FRN) and swish activations. The models are implemented in [`haiku`](https://github.com/deepmind/dm-haiku).\n",
        "\n",
        "We also define the categorical log-likelihood (`log_likelihood_fn`), the Gaussian prior log-density (`log_prior_fn`), and the corresponding posterior log-density (`log_posterior_fn`). `log_posterior_wgrad_fn` computes the posterior log-density and its gradients with respect to the parameters of the model.\n",
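        "\n",
        "Concretely, the (unnormalized) posterior log-density being optimized is\n",
        "$$\\log p(\\theta \\mid \\mathcal{D}) = \\log p(\\mathcal{D} \\mid \\theta) + \\log p(\\theta) + \\mathrm{const},$$\n",
        "where $\\log p(\\mathcal{D} \\mid \\theta)$ is the softmax log-likelihood and $p(\\theta) = \\mathcal{N}(0,\\, 5 I)$ is the prior over the flattened model parameters; in the code below the likelihood term is estimated on a mini-batch of data.\n",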
        "\n",
        "The `evaluate_fn` function computes the accuracy and predictions of the model on a given dataset; we will use this function to generate the predictions for our submission."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KZAbsDh-tNji"
      },
      "source": [
        "net_apply, net_init = models.get_model(\"resnet20_frn_swish\", data_info={\"num_classes\": 10})\n",
        "prior_variance = 5.\n",
        "\n",
        "def log_likelihood_fn(params, batch, is_training=True):\n",
        "    \"\"\"Computes the log-likelihood.\"\"\"\n",
        "    x, y = batch\n",
        "    logits = net_apply(params, None, x, is_training)\n",
        "    num_classes = logits.shape[-1]\n",
        "    labels = jax.nn.one_hot(y, num_classes)\n",
        "    log_lik = jnp.sum(labels * jax.nn.log_softmax(logits))\n",
        "\n",
        "    return log_lik\n",
        "\n",
        "\n",
        "def log_prior_fn(params):\n",
        "    \"\"\"Computes the Gaussian prior log-density.\"\"\"\n",
        "    n_params = sum([p.size for p in jax.tree_leaves(params)])\n",
        "    exp_term = sum(jax.tree_leaves(jax.tree_map(\n",
        "        lambda p: (-p**2 / (2 * prior_variance)).sum(), params)))\n",
        "    norm_constant = -0.5 * n_params * jnp.log((2 * math.pi * prior_variance))\n",
        "    return exp_term + norm_constant\n",
        "\n",
        "\n",
        "def log_posterior_fn(params, batch, is_training=True):\n",
        "    log_lik = log_likelihood_fn(params, batch, is_training=is_training)\n",
        "    log_prior = log_prior_fn(params)\n",
        "    return log_lik + log_prior\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def get_accuracy_fn(batch, params):\n",
        "    x, y = batch\n",
        "    logits = net_apply(params, None, x, False)\n",
        "    probs = jax.nn.softmax(logits, axis=1)\n",
        "    preds = jnp.argmax(logits, axis=1)\n",
        "    accuracy = (preds == y).mean()\n",
        "    return accuracy, probs\n",
        "\n",
        "\n",
        "def evaluate_fn(dataset, params):\n",
        "    sum_accuracy = 0\n",
        "    all_probs = []\n",
        "    for x, y in dataset:\n",
        "        x, y = jnp.asarray(x), jnp.asarray(y)\n",
        "        batch_accuracy, batch_probs = get_accuracy_fn((x, y), params)\n",
        "        sum_accuracy += batch_accuracy.item()\n",
        "        all_probs.append(onp.asarray(batch_probs))\n",
        "    all_probs = jnp.concatenate(all_probs, axis=0)\n",
        "    return sum_accuracy / len(dataset), all_probs\n",
        "\n",
        "\n",
        "log_posterior_wgrad_fn = jax.jit(jax.value_and_grad(log_posterior_fn, argnums=0))"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WmNpx-CMTyyM"
      },
      "source": [
        "Now we can initialize the parameters of the model. "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4-AtnpOethf_"
      },
      "source": [
        "seed = jax.random.PRNGKey(0)\n",
        "key, net_init_key = jax.random.split(seed, 2)\n",
        "init_data, _ = next(iter(train_dataset))\n",
        "init_data = jnp.asarray(init_data)\n",
        "params = net_init(net_init_key, init_data, True)"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEZnIWNmUEwE"
      },
      "source": [
        "## Optimizer\n",
        "\n",
        "In this colab we will train an approximate maximum-a-posteriori (MAP) solution as our submission for simplicity. You can find efficient implementations of more advanced baselines [here](https://github.com/google-research/google-research/tree/master/bnn_hmc).\n",
        "\n",
        "We will use SGD with momentum. You can adjust the hyper-parameters or switch to a different optimizer by changing the code below.\n",
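        "\n",
        "The learning rate is decayed with a cosine schedule, $\\eta_t = \\tfrac{1}{2}\\,\\eta_0\\left(1 + \\cos(\\pi\\, t / T)\\right)$, where $t$ is the current step and $T$ is the total number of optimization steps.\n",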
        "\n",
        "`update_fn` implements a single optimization step on a mini-batch of data."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fLGw_AVJtjy3"
      },
      "source": [
        "batch_size = 100\n",
        "test_batch_size = 1000\n",
        "num_epochs = 5\n",
        "shuffle_buffer_size = 1000\n",
        "momentum_decay = 0.9\n",
        "init_lr = 1.e-3\n",
        "\n",
        "train_dataset_batched = train_dataset.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)\n",
        "train_dataset_batched = train_dataset_batched.batch(batch_size)\n",
        "test_dataset_batched = test_dataset.batch(test_batch_size)\n",
        "\n",
        "epoch_steps = len(train_dataset_batched)\n",
        "total_steps = epoch_steps * num_epochs\n",
        "\n",
        "\n",
        "def lr_schedule(step):\n",
        "    t = step / total_steps\n",
        "    return 0.5 * init_lr * (1 + jnp.cos(t * onp.pi))\n",
        "\n",
        "optimizer = optax.chain(\n",
        "    optax.trace(decay=momentum_decay, nesterov=False),\n",
        "    optax.scale_by_schedule(lr_schedule))\n",
        "opt_state = optimizer.init(params)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def update_fn(batch, params, opt_state):\n",
        "    x, y = batch\n",
        "    loss, grad = log_posterior_wgrad_fn(params, (x, y))\n",
        "    updates, new_opt_state = optimizer.update(grad, opt_state)\n",
        "    new_params = optax.apply_updates(params, updates)\n",
        "    return new_params, new_opt_state, loss"
      ],
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "byrbWJFvVACJ"
      },
      "source": [
        "## Training\n",
        "\n",
        "We will run training for 5 epochs, which can take several minutes to complete. Note that to achieve good results you would need to train for substantially longer and tune the hyper-parameters."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zzHlKEK1uKBS",
        "outputId": "40eb84e5-291d-49d4-a295-ef65a607a53d"
      },
      "source": [
        "for epoch in range(num_epochs):\n",
        "    sum_loss = 0.\n",
        "    for x, y in train_dataset_batched:\n",
        "        x, y = jnp.asarray(x), jnp.asarray(y)\n",
        "        params, opt_state, loss = update_fn((x, y), params, opt_state)\n",
        "        sum_loss += loss\n",
        "    \n",
        "    test_acc, all_test_probs = evaluate_fn(test_dataset_batched, params)\n",
        "    print(\"Epoch {}\".format(epoch))\n",
        "    print(\"\\tAverage loss: {}\".format(sum_loss / epoch_steps))\n",
        "    print(\"\\tTest accuracy: {}\".format(test_acc))\n",
        "\n",
        "all_test_probs = onp.asarray(all_test_probs)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0\n",
            "\tAverage loss: -472672.8125\n",
            "\tTest accuracy: 0.29680001437664033\n",
            "Epoch 1\n",
            "\tAverage loss: -472569.15625\n",
            "\tTest accuracy: 0.35730001628398894\n",
            "Epoch 2\n",
            "\tAverage loss: -472546.53125\n",
            "\tTest accuracy: 0.41730002164840696\n",
            "Epoch 3\n",
            "\tAverage loss: -472528.9375\n",
            "\tTest accuracy: 0.4814000278711319\n",
            "Epoch 4\n",
            "\tAverage loss: -472512.15625\n",
            "\tTest accuracy: 0.5193000257015228\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QKDfvlHzxulX"
      },
      "source": [
        "## Evaluating metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EiQcDfma00Qn"
      },
      "source": [
        "The starter kit includes a `metrics` module that computes the agreement and total variation distance metrics used in the competition."
      ]
    },
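    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For predicted probabilities $\\hat{p}_i$ and HMC reference probabilities $p_i$ over $K$ classes at $n$ test points, these metrics are (assuming the standard definitions from the competition description):\n",
        "$$\\mathrm{agreement} = \\frac{1}{n} \\sum_{i=1}^{n} \\mathbb{1}\\left[\\arg\\max_k \\hat{p}_{ik} = \\arg\\max_k p_{ik}\\right], \\qquad \\mathrm{TVD} = \\frac{1}{n} \\sum_{i=1}^{n} \\frac{1}{2} \\sum_{k=1}^{K} \\left|\\hat{p}_{ik} - p_{ik}\\right|.$$\n"
      ]
    },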
    {
      "cell_type": "code",
      "metadata": {
        "id": "XNr_UlytxxB8"
      },
      "source": [
        "import metrics"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zu-gQzoY05P2"
      },
      "source": [
        "We can load the HMC reference predictions from the starter kit as well."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gdIagtdhxyRu"
      },
      "source": [
        "with open('neurips_bdl_starter_kit/data/cifar10/probs.csv', 'r') as fp:\n",
        "  reference = onp.loadtxt(fp)"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "80T57WFH09Eh"
      },
      "source": [
        "Now we can compute the metrics!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TCSC3Svbx-0M",
        "outputId": "8b1f7878-4caa-4741-cd50-c033a8d28346"
      },
      "source": [
        "metrics.agreement(all_test_probs, reference)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.5274"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fJOvBOe0zmhT",
        "outputId": "e911ef29-670d-4781-ee39-85483b741e66"
      },
      "source": [
        "metrics.total_variation_distance(all_test_probs, reference)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.5339707105455799"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y7nYa4w3VmnO"
      },
      "source": [
        "## Preparing the submission\n",
        "\n",
        "Once you run the code above, `all_test_probs` should contain a `10000 x 10` array, where rows correspond to test data points and columns to classes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "atqsSbjBInpg",
        "outputId": "80999871-f556-4656-c464-a7a2cb11433b"
      },
      "source": [
        "all_test_probs.shape"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(10000, 10)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
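    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Beyond the shape, it is worth verifying that each row of `all_test_probs` is a valid categorical distribution before saving it. This quick sanity check is our addition, not part of the starter kit:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "# each row should contain non-negative entries that sum to one\n",
        "assert onp.all(all_test_probs >= 0.)\n",
        "assert onp.allclose(all_test_probs.sum(axis=1), 1., atol=1e-3)"
      ],
      "execution_count": null,
      "outputs": []
    },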
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FB2LhtfUW1ar"
      },
      "source": [
        "Now we need to save the array as `cifar10_probs.csv` and create a zip archive containing it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h0CtqrJCI2-F",
        "outputId": "5d29b7e2-cac9-4cf3-9544-4a287a42691e"
      },
      "source": [
        "onp.savetxt(\"cifar10_probs.csv\", all_test_probs)\n",
        "\n",
        "!zip submission.zip cifar10_probs.csv"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  adding: cifar10_probs.csv (deflated 56%)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f-3deTsUNcqI"
      },
      "source": [
        "Finally, you can download the submission by running the code below. If the download doesn't start, check whether your browser blocked it."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "ieCUqCPjJP0k",
        "outputId": "5d4d1803-6b81-416e-d23a-e50ba82ea086"
      },
      "source": [
        "from google.colab import files\n",
        "files.download('submission.zip') "
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_2831dbf0-34de-4d48-ae9f-c20b03a0700a\", \"submission.zip\", 1089933)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "izLlUVM4XIWv"
      },
      "source": [
        "Now you can head over to the [submission system](https://competitions.codalab.org/competitions/33512?secret_key=10f23c1f-9c86-4a7a-8406-d85b0a0713f2#participate) and upload your submission. Good luck :)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iOP_7nqoz0VI"
      },
      "source": [
        "## Adding IMDB results\n",
        "\n",
        "Here we train our CNN-LSTM model on IMDB with SGD. To keep the code simple, we copy the code above, making only the necessary modifications."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NOSS0brwz_OD"
      },
      "source": [
        "x_train = onp.loadtxt(\"imdb_train_x.csv\").astype(int)\n",
        "y_train = onp.loadtxt(\"imdb_train_y.csv\")\n",
        "x_test = onp.loadtxt(\"imdb_test_x.csv\").astype(int)\n",
        "y_test = onp.loadtxt(\"imdb_test_y.csv\")\n",
        "\n",
        "\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))"
      ],
      "execution_count": 28,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nUyfAuU5z32b"
      },
      "source": [
        "net_apply, net_init = models.get_model(\"cnn_lstm\", data_info={\"num_classes\": 2})\n",
        "prior_variance = 1.\n",
        "\n",
        "def log_likelihood_fn(params, batch, is_training=True):\n",
        "    \"\"\"Computes the log-likelihood.\"\"\"\n",
        "    x, y = batch\n",
        "    logits = net_apply(params, None, x, is_training)\n",
        "    num_classes = logits.shape[-1]\n",
        "    labels = jax.nn.one_hot(y, num_classes)\n",
        "    softmax_xent = jnp.sum(labels * jax.nn.log_softmax(logits))\n",
        "\n",
        "    return softmax_xent\n",
        "\n",
        "\n",
        "def log_prior_fn(params):\n",
        "    \"\"\"Computes the Gaussian prior log-density.\"\"\"\n",
        "    n_params = sum([p.size for p in jax.tree_leaves(params)])\n",
        "    exp_term = sum(jax.tree_leaves(jax.tree_map(\n",
        "        lambda p: (-p**2 / (2 * prior_variance)).sum(), params)))\n",
        "    norm_constant = -0.5 * n_params * jnp.log((2 * math.pi * prior_variance))\n",
        "    return exp_term + norm_constant\n",
        "\n",
        "\n",
        "def log_posterior_fn(params, batch, is_training=True):\n",
        "    log_lik = log_likelihood_fn(params, batch, is_training=is_training)\n",
        "    log_prior = log_prior_fn(params)\n",
        "    return log_lik + log_prior\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def get_accuracy_fn(batch, params):\n",
        "    x, y = batch\n",
        "    logits = net_apply(params, None, x, False)\n",
        "    probs = jax.nn.softmax(logits, axis=1)\n",
        "    preds = jnp.argmax(logits, axis=1)\n",
        "    accuracy = (preds == y).mean()\n",
        "    return accuracy, probs\n",
        "\n",
        "\n",
        "def evaluate_fn(dataset, params):\n",
        "    sum_accuracy = 0\n",
        "    all_probs = []\n",
        "    for x, y in dataset:\n",
        "        x, y = jnp.asarray(x), jnp.asarray(y)\n",
        "        batch_accuracy, batch_probs = get_accuracy_fn((x, y), params)\n",
        "        sum_accuracy += batch_accuracy.item()\n",
        "        all_probs.append(onp.asarray(batch_probs))\n",
        "    all_probs = jnp.concatenate(all_probs, axis=0)\n",
        "    return sum_accuracy / len(dataset), all_probs\n",
        "\n",
        "\n",
        "log_posterior_wgrad_fn = jax.jit(jax.value_and_grad(log_posterior_fn, argnums=0))"
      ],
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y8PcJL5Y28k0"
      },
      "source": [
        "seed = jax.random.PRNGKey(0)\n",
        "key, net_init_key = jax.random.split(seed, 2)\n",
        "init_data, _ = next(iter(train_dataset))\n",
        "init_data = jnp.asarray(init_data)[None, :]\n",
        "params = net_init(net_init_key, init_data, True)"
      ],
      "execution_count": 47,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5C7-IAkW0VJV"
      },
      "source": [
        "batch_size = 100\n",
        "test_batch_size = 1000\n",
        "num_epochs = 20\n",
        "shuffle_buffer_size = 1000\n",
        "momentum_decay = 0.9\n",
        "init_lr = 3.e-4\n",
        "\n",
        "train_dataset_batched = train_dataset.shuffle(shuffle_buffer_size, reshuffle_each_iteration=True)\n",
        "train_dataset_batched = train_dataset_batched.batch(batch_size)\n",
        "test_dataset_batched = test_dataset.batch(test_batch_size)\n",
        "\n",
        "epoch_steps = len(train_dataset_batched)\n",
        "total_steps = epoch_steps * num_epochs\n",
        "\n",
        "\n",
        "def lr_schedule(step):\n",
        "    t = step / total_steps\n",
        "    return 0.5 * init_lr * (1 + jnp.cos(t * onp.pi))\n",
        "\n",
        "optimizer = optax.chain(\n",
        "    optax.trace(decay=momentum_decay, nesterov=False),\n",
        "    optax.scale_by_schedule(lr_schedule))\n",
        "opt_state = optimizer.init(params)\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def update_fn(batch, params, opt_state):\n",
        "    x, y = batch\n",
        "    loss, grad = log_posterior_wgrad_fn(params, (x, y))\n",
        "    updates, new_opt_state = optimizer.update(grad, opt_state)\n",
        "    new_params = optax.apply_updates(params, updates)\n",
        "    return new_params, new_opt_state, loss"
      ],
      "execution_count": 48,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_TyJ7ibZ2XmX",
        "outputId": "9631fc01-2113-4b06-a65a-7da3c013c82b"
      },
      "source": [
        "for epoch in range(num_epochs):\n",
        "    sum_loss = 0.\n",
        "    for x, y in train_dataset_batched:\n",
        "        x, y = jnp.asarray(x), jnp.asarray(y)\n",
        "        params, opt_state, loss = update_fn((x, y), params, opt_state)\n",
        "        sum_loss += loss\n",
        "    \n",
        "    test_acc, all_test_probs = evaluate_fn(test_dataset_batched, params)\n",
        "    print(\"Epoch {}\".format(epoch))\n",
        "    print(\"\\tAverage loss: {}\".format(sum_loss / epoch_steps))\n",
        "    print(\"\\tTest accuracy: {}\".format(test_acc))\n",
        "\n",
        "all_test_probs = onp.asarray(all_test_probs)"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 0\n",
            "\tAverage loss: -3015065.75\n",
            "\tTest accuracy: 0.7025200343132019\n",
            "Epoch 1\n",
            "\tAverage loss: -2596647.75\n",
            "\tTest accuracy: 0.6938000321388245\n",
            "Epoch 2\n",
            "\tAverage loss: -2506782.75\n",
            "\tTest accuracy: 0.735680034160614\n",
            "Epoch 3\n",
            "\tAverage loss: -2487187.5\n",
            "\tTest accuracy: 0.7609600353240967\n",
            "Epoch 4\n",
            "\tAverage loss: -2482742.0\n",
            "\tTest accuracy: 0.8123200368881226\n",
            "Epoch 5\n",
            "\tAverage loss: -2481675.25\n",
            "\tTest accuracy: 0.7172400331497193\n",
            "Epoch 6\n",
            "\tAverage loss: -2481397.0\n",
            "\tTest accuracy: 0.8155200386047363\n",
            "Epoch 7\n",
            "\tAverage loss: -2481321.5\n",
            "\tTest accuracy: 0.8262000417709351\n",
            "Epoch 8\n",
            "\tAverage loss: -2481293.5\n",
            "\tTest accuracy: 0.809320044517517\n",
            "Epoch 9\n",
            "\tAverage loss: -2481283.5\n",
            "\tTest accuracy: 0.7610400342941284\n",
            "Epoch 10\n",
            "\tAverage loss: -2481282.75\n",
            "\tTest accuracy: 0.7161200308799743\n",
            "Epoch 11\n",
            "\tAverage loss: -2481282.0\n",
            "\tTest accuracy: 0.8052400350570679\n",
            "Epoch 12\n",
            "\tAverage loss: -2481279.75\n",
            "\tTest accuracy: 0.8369200444221496\n",
            "Epoch 13\n",
            "\tAverage loss: -2481278.25\n",
            "\tTest accuracy: 0.8426000404357911\n",
            "Epoch 14\n",
            "\tAverage loss: -2481276.75\n",
            "\tTest accuracy: 0.8396400427818298\n",
            "Epoch 15\n",
            "\tAverage loss: -2481274.0\n",
            "\tTest accuracy: 0.811960039138794\n",
            "Epoch 16\n",
            "\tAverage loss: -2481273.5\n",
            "\tTest accuracy: 0.8345200443267822\n",
            "Epoch 17\n",
            "\tAverage loss: -2481270.75\n",
            "\tTest accuracy: 0.8319200420379639\n",
            "Epoch 18\n",
            "\tAverage loss: -2481268.25\n",
            "\tTest accuracy: 0.8330400443077087\n",
            "Epoch 19\n",
            "\tAverage loss: -2481267.0\n",
            "\tTest accuracy: 0.8375200390815735\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f1ZVIHZa5FfS"
      },
      "source": [
        "## Final submission\n",
        "\n",
        "Now we can combine the prediction files that we produced into a single submission zip."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "id": "rL5JmXB82qbT",
        "outputId": "80434232-a2b1-4e90-81d7-85348433cef4"
      },
      "source": [
        "onp.savetxt(\"imdb_probs.csv\", all_test_probs)\n",
        "\n",
        "!zip submission.zip cifar10_probs.csv imdb_probs.csv\n",
        "files.download('submission.zip') "
      ],
      "execution_count": 52,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "updating: cifar10_probs.csv (deflated 56%)\n",
            "updating: imdb_probs.csv (deflated 58%)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_2c61f50f-c3b5-48fa-a8e7-2ba0ed6569ee\", \"submission.zip\", 1614064)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ytDtGrCL5YCA"
      },
      "source": [
        "Now, let us head over to the [submission platform](https://competitions.codalab.org/competitions/33647) and upload the submission."
      ]
    }
  ]
}